# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Serialization utilities for Jax.

All Flax classes that carry state (e.g., Optimizer) can be turned into a
state dict of numpy arrays for easy serialization.
"""
import enum
import threading
from contextlib import contextmanager
from typing import Any

import jax
import msgpack
import numpy as np

# Maps a registered type to its ``(ty_to_state_dict, ty_from_state_dict)``
# handler pair; entries are added by ``register_serialization_state``.
_STATE_DICT_REGISTRY: dict[Any, Any] = {}


class _ErrorContext(threading.local):
  """Thread-local holder of the state-dict path used in error messages."""

  def __init__(self):
    # Names of the branches entered so far, outermost first.
    self.path = list()


# Module-level instance; being a ``threading.local``, its ``path`` attribute is
# kept separately for every thread.
_error_context = _ErrorContext()


@contextmanager
def _record_path(name):
  """Push ``name`` onto the current path for the duration of the ``with`` body.

  The entry is popped again on exit, including when the body raises.
  """
  path = _error_context.path
  try:
    path.append(name)
    yield
  finally:
    path.pop()


def current_path():
  """Return the '/'-joined branch names currently recorded for this thread.

  Intended for building deserialization error messages.
  """
  separator = '/'
  return separator.join(_error_context.path)


class _NamedTuple:
  """Marker type under which the namedtuple handlers are registered.

  ``to_state_dict`` and ``from_state_dict`` look up this class, rather than
  the concrete type, whenever the target is detected to be a namedtuple.
  """


def _is_namedtuple(x):
  """Return True if ``x`` is a tuple instance that exposes ``_fields``.

  This is a duck-typing check: any tuple subclass instance carrying a
  ``_fields`` attribute is treated as a namedtuple.
  """
  if not isinstance(x, tuple):
    return False
  return hasattr(x, '_fields')


def from_state_dict(target, state: dict[str, Any], name: str = '.'):
  """Restores the state of the given target using a state dict.

  The target is passed in so that the registered restore handler knows the
  exact structure to rebuild; handlers may also use it to check that the
  restored data is compatible.

  Targets whose type has no registered handler are treated as leaves: the
  given ``state`` is returned unchanged.

  Args:
    target: the object of which the state should be restored.
    state: a dictionary generated by ``to_state_dict`` with the desired new
           state for ``target``.
    name: name of branch taken, used to improve deserialization error messages.
  Returns:
    A copy of the object with the restored state.
  """
  # All namedtuples share one registry entry, keyed by the marker type.
  registry_key = _NamedTuple if _is_namedtuple(target) else type(target)
  if registry_key not in _STATE_DICT_REGISTRY:
    return state
  restore_fn = _STATE_DICT_REGISTRY[registry_key][1]
  # Record the branch name so nested errors can report where they occurred.
  with _record_path(name):
    return restore_fn(target, state)


def to_state_dict(target) -> dict[str, Any]:
  """Returns a dictionary with the state of the given target.

  Targets whose type has no registered handler are returned unchanged. When
  the handler produces a dict, every key is asserted to be a string.
  """
  # All namedtuples share one registry entry, keyed by the marker type.
  registry_key = _NamedTuple if _is_namedtuple(target) else type(target)
  if registry_key not in _STATE_DICT_REGISTRY:
    return target

  serialize_fn = _STATE_DICT_REGISTRY[registry_key][0]
  result = serialize_fn(target)
  if isinstance(result, dict):
    for key in result:
      assert isinstance(key, str), 'A state dict must only have string keys.'
  return result


def is_serializable(target):
  """Return whether a serialization handler is registered for ``target``.

  Args:
    target: either a type, or an instance whose type is looked up.

  Returns:
    True if the type is a key of the state-dict registry. The lookup uses the
    concrete type only; it does not apply the namedtuple marker substitution
    that ``to_state_dict`` and ``from_state_dict`` perform.
  """
  ty = target if isinstance(target, type) else type(target)
  return ty in _STATE_DICT_REGISTRY


def register_serialization_state(
  ty, ty_to_state_dict, ty_from_state_dict, override=False
):
  """Register a type for serialization.

  Args:
    ty: the type to be registered
    ty_to_state_dict: a function that takes an instance of ty and
      returns its state as a dictionary.
    ty_from_state_dict: a function that takes an instance of ty and
      a state dict, and returns a copy of the instance with the restored state.
    override: override a previously registered serialization handler
      (default: False).

  Raises:
    ValueError: if ``ty`` already has a handler and ``override`` is false.
  """
  if override or ty not in _STATE_DICT_REGISTRY:
    _STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict)
    return
  raise ValueError(
    f'a serialization handler for "{ty.__name__}" is already registered'
  )


def _list_state_dict(xs: list[Any]) -> dict[str, Any]:
  """Serialize a sequence as a dict keyed by the stringified element index."""
  state = {}
  for index, item in enumerate(xs):
    state[str(index)] = to_state_dict(item)
  return state


def _restore_list(xs, state_dict: dict[str, Any]) -> list[Any]:
  """Restore a list from an index-keyed state dict.

  Args:
    xs: target list providing the structure of each element.
    state_dict: dict with keys ``'0'`` .. ``'n-1'`` holding element states.

  Returns:
    A new list with every element restored via ``from_state_dict``.

  Raises:
    ValueError: if ``state_dict`` and ``xs`` differ in length.
  """
  n_state = len(state_dict)
  n_target = len(xs)
  if n_state != n_target:
    raise ValueError(
      'The size of the list and the state dict do not match,'
      f' got {n_target} and {n_state} '
      f'at path {current_path()}'
    )
  return [
    from_state_dict(xs[i], state_dict[str(i)], name=str(i))
    for i in range(n_state)
  ]


def _dict_state_dict(xs: dict[str, Any]) -> dict[str, Any]:
  """Serialize a dict, stringifying its keys.

  Raises:
    ValueError: if two distinct keys have the same ``str()`` representation.
  """
  str_keys = set(map(str, xs.keys()))
  if len(str_keys) != len(xs):
    raise ValueError(
      'Dict keys do not have a unique string representation: '
      f'{str_keys} vs given: {xs}'
    )
  state = {}
  for key, value in xs.items():
    name = str(key)
    state[name] = to_state_dict(value)
  return state


def _restore_dict(xs, states: dict[str, Any]) -> dict[str, Any]:
  """Restore a dict from a state dict keyed by stringified target keys.

  Only the keys of the target ``xs`` are restored; the result keeps the
  original (non-stringified) keys of ``xs``.

  Raises:
    ValueError: if a key of ``xs`` has no counterpart in ``states``.
  """
  target_keys = {str(key) for key in xs.keys()}
  missing = target_keys.difference(states.keys())
  if missing:
    raise ValueError(
      'The target dict keys and state dict keys do not match, target dict'
      f' contains keys {missing} which are not present in state dict at path'
      f' {current_path()}'
    )

  restored = {}
  for key, value in xs.items():
    name = str(key)
    restored[key] = from_state_dict(value, states[name], name=name)
  return restored


def _namedtuple_state_dict(nt) -> dict[str, Any]:
  """Serialize a namedtuple as a dict from field name to field state."""
  state = {}
  for field in nt._fields:
    state[field] = to_state_dict(getattr(nt, field))
  return state


def _restore_namedtuple(xs, state_dict: dict[str, Any]):
  """Rebuild namedtuple from serialized dict.

  Args:
    xs: target namedtuple providing the type and per-field structure.
    state_dict: dict from field name to field state, or the older layout with
      exactly the keys ``name``, ``fields`` and ``values``.

  Returns:
    A new instance of ``type(xs)`` with all fields restored.

  Raises:
    ValueError: if the field names of ``state_dict`` and ``xs`` differ.
  """
  legacy_keys = {'name', 'fields', 'values'}
  if set(state_dict.keys()) == legacy_keys:
    # TODO(jheek): remove backward compatible named tuple restoration early 2022
    names = state_dict['fields']
    values = state_dict['values']
    state_dict = {names[str(i)]: values[str(i)] for i in range(len(names))}

  sd_keys = set(state_dict.keys())
  nt_keys = set(xs._fields)
  if sd_keys != nt_keys:
    raise ValueError(
      'The field names of the state dict and the named tuple do not match,'
      f' got {sd_keys} and {nt_keys} at path {current_path()}'
    )

  restored = {}
  for field, field_state in state_dict.items():
    restored[field] = from_state_dict(
      getattr(xs, field), field_state, name=field
    )
  return type(xs)(**restored)


# Handlers for the built-in containers. Lists and tuples share the
# index-keyed state dict produced by ``_list_state_dict``; a tuple is restored
# through the list path and then converted back to a tuple.
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(
  tuple,
  _list_state_dict,
  lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)),
)
# Namedtuples are registered once, under the ``_NamedTuple`` marker type.
register_serialization_state(
  _NamedTuple, _namedtuple_state_dict, _restore_namedtuple
)

# ``jax.tree_util.Partial``: only the bound positional and keyword arguments
# are serialized; the wrapped callable is taken from the restore target.
# The branch names are passed explicitly so that errors raised while restoring
# the arguments report ``args`` / ``keywords`` in their path instead of the
# default '.' placeholder.
register_serialization_state(
  jax.tree_util.Partial,
  lambda x: (
    {
      'args': to_state_dict(x.args),
      'keywords': to_state_dict(x.keywords),
    }
  ),
  lambda x, sd: jax.tree_util.Partial(
    x.func,
    *from_state_dict(x.args, sd['args'], name='args'),
    **from_state_dict(x.keywords, sd['keywords'], name='keywords'),
  ),
)

# On-the-wire / disk serialization format

# We encode state-dicts via msgpack, using its custom type extension.
# https://github.com/msgpack/msgpack/blob/master/spec.md
#
# - ndarrays and jax Arrays are serialized to a nested msgpack-encoded bytes
#   object holding (shape-tuple, dtype-name (e.g. 'float32'), row-major
#   array-bytes).
#   Note: only simple ndarray types are supported, no objects or fields.
#
# - numpy scalars are encoded like ndarrays, but under their own extension
#   type so that they are restored as scalars.
#
# - native complex scalars are converted to nested msgpack-encoded tuples
#   (real, imag).


def _ndarray_to_bytes(arr) -> bytes:
  """Save ndarray to simple msgpack encoding.

  Args:
    arr: a numpy ndarray or ``jax.Array``; a jax array is first copied into a
      numpy array.

  Returns:
    msgpack-encoded bytes of ``(shape, dtype name, row-major buffer)``.

  Raises:
    ValueError: if the dtype holds Python objects or is a structured dtype.
  """
  if isinstance(arr, jax.Array):
    arr = np.array(arr)
  # ``dtype.fields`` is non-None for every structured dtype, whereas
  # ``isalignedstruct`` is False for packed (unaligned) ones. Those would be
  # written under a 'voidNN' name and lose their fields when restored.
  if arr.dtype.hasobject or arr.dtype.fields is not None:
    raise ValueError(
      'Object and structured dtypes not supported '
      'for serialization of ndarrays.'
    )
  tpl = (arr.shape, arr.dtype.name, arr.tobytes('C'))
  return msgpack.packb(tpl, use_bin_type=True)


def _dtype_from_name(name: 'str | bytes'):
  """Handle JAX bfloat16 dtype correctly.

  Args:
    name: dtype name, as text or as bytes. ``_ndarray_from_bytes`` passes
      bytes because it unpacks with ``raw=True``.

  Returns:
    ``jax.numpy.bfloat16`` for the name 'bfloat16', otherwise ``np.dtype(name)``.
  """
  # Compare against a literal of the same kind as ``name`` so that both the
  # text and the bytes spelling are recognised without mixing str and bytes.
  bfloat16_name = b'bfloat16' if isinstance(name, bytes) else 'bfloat16'
  if name == bfloat16_name:
    return jax.numpy.bfloat16
  return np.dtype(name)


def _ndarray_from_bytes(data: bytes) -> np.ndarray:
  """Load ndarray from simple msgpack encoding.

  ``data`` is expected to unpack into ``(shape, dtype name, buffer)``, the
  layout written by ``_ndarray_to_bytes``.
  """
  shape, dtype_name, buffer = msgpack.unpackb(data, raw=True)
  dtype = _dtype_from_name(dtype_name)
  flat = np.frombuffer(buffer, dtype=dtype)
  return flat.reshape(shape, order='C')


class _MsgpackExtType(enum.IntEnum):
  """Messagepack custom type ids.

  These integers are the extension codes passed to ``msgpack.ExtType`` by
  ``_msgpack_ext_pack`` and matched again in ``_msgpack_ext_unpack``.
  """

  ndarray = 1  # numpy ndarray or jax.Array
  native_complex = 2  # Python ``complex``
  npscalar = 3  # numpy scalar (``np.generic`` instance)


def _msgpack_ext_pack(x):
  """Messagepack encoders for custom types.

  Arrays, numpy scalars and Python complex numbers are wrapped in a
  ``msgpack.ExtType``; any other value is returned unchanged.
  """
  # TODO(flax-dev): Array here only work when they are fully addressable.
  # If they are not fully addressable, use the GDA path for checkpointing.
  if isinstance(x, (np.ndarray, jax.Array)):
    code = _MsgpackExtType.ndarray
    payload = _ndarray_to_bytes(x)
  elif isinstance(x, np.generic):
    # pack scalar as ndarray
    code = _MsgpackExtType.npscalar
    payload = _ndarray_to_bytes(np.asarray(x))
  elif isinstance(x, complex):
    code = _MsgpackExtType.native_complex
    payload = msgpack.packb((x.real, x.imag))
  else:
    return x
  return msgpack.ExtType(code, payload)


def _msgpack_ext_unpack(code, data):
  """Messagepack decoders for custom types.

  Unknown extension codes are handed back as a ``msgpack.ExtType``.
  """
  if code == _MsgpackExtType.ndarray:
    return _ndarray_from_bytes(data)
  if code == _MsgpackExtType.native_complex:
    parts = msgpack.unpackb(data)
    return complex(parts[0], parts[1])
  if code == _MsgpackExtType.npscalar:
    # Index with an empty tuple to unpack the ndarray to a scalar.
    return _ndarray_from_bytes(data)[()]
  return msgpack.ExtType(code, data)


# Chunking array leaves

# msgpack has a hard limit of 2**31 - 1 bytes per object leaf.  To circumvent
# this limit for giant arrays (e.g. embedding tables), we traverse the tree
# and break up arrays near the limit into flattened array chunks.

# True limit is 2**31 - 1, but leave a margin for encoding padding.
# Measured in bytes: array leaves larger than this are split into chunks.
MAX_CHUNK_SIZE = 2**30


def _np_convert_in_place(d):
  """Convert any jax devicearray leaves to numpy arrays in place.

  Only nested dicts are traversed. A dict argument is mutated and returned; a
  bare jax array is returned as a new numpy array; anything else is returned
  unchanged.
  """
  if isinstance(d, dict):
    for key, value in d.items():
      if isinstance(value, jax.Array):
        d[key] = np.array(value)
      elif isinstance(value, dict):
        _np_convert_in_place(value)
    return d
  if isinstance(d, jax.Array):
    return np.array(d)
  return d


def _tuple_to_dict(tpl):
  """Map a sequence to a dict keyed by the stringified element index."""
  return {str(index): item for index, item in enumerate(tpl)}


def _dict_to_tuple(dct):
  """Invert ``_tuple_to_dict``: read keys '0'..'n-1' back into a tuple."""
  return tuple(dct[str(index)] for index in range(len(dct)))


def _chunk(arr) -> dict[str, Any]:
  """Convert array to a canonical dictionary of chunked arrays.

  The result carries a ``'__msgpack_chunked_array__'`` marker, the original
  shape, and the flattened array cut into consecutive pieces, the latter two
  as index-keyed dicts.
  """
  # Number of elements per chunk, never less than one.
  chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize))
  flat = arr.reshape(-1)
  pieces = []
  for start in range(0, flat.size, chunksize):
    pieces.append(flat[start : start + chunksize])
  return {
    '__msgpack_chunked_array__': True,
    'shape': _tuple_to_dict(arr.shape),
    'chunks': _tuple_to_dict(pieces),
  }


def _unchunk(data: dict[str, Any]):
  """Convert canonical dictionary of chunked arrays back into array.

  The chunks are concatenated in index order and reshaped to the stored shape.
  """
  assert '__msgpack_chunked_array__' in data
  shape = _dict_to_tuple(data['shape'])
  pieces = _dict_to_tuple(data['chunks'])
  return np.concatenate(pieces).reshape(shape)


def _chunk_array_leaves_in_place(d):
  """Convert oversized array leaves to safe chunked form in place.

  Only nested dicts are traversed. A dict argument is mutated and returned; a
  bare oversized ndarray is returned in chunked form; anything else is
  returned unchanged.
  """
  if isinstance(d, dict):
    for key, value in d.items():
      if isinstance(value, np.ndarray):
        n_bytes = value.size * value.dtype.itemsize
        if n_bytes > MAX_CHUNK_SIZE:
          d[key] = _chunk(value)
      elif isinstance(value, dict):
        _chunk_array_leaves_in_place(value)
    return d
  if isinstance(d, np.ndarray) and d.size * d.dtype.itemsize > MAX_CHUNK_SIZE:
    return _chunk(d)
  return d


def _unchunk_array_leaves_in_place(d):
  """Convert chunked array leaves back into array leaves, in place.

  Only nested dicts are traversed. A dict that is itself a chunked array is
  returned as the reassembled array; otherwise the (mutated) input is
  returned.
  """
  if not isinstance(d, dict):
    return d
  if '__msgpack_chunked_array__' in d:
    return _unchunk(d)
  for key, value in d.items():
    if not isinstance(value, dict):
      continue
    if '__msgpack_chunked_array__' in value:
      d[key] = _unchunk(value)
    else:
      _unchunk_array_leaves_in_place(value)
  return d


# User-facing API calls:


def msgpack_serialize(pytree, in_place: bool = False) -> bytes:
  """Save data structure to bytes in msgpack format.

  Low-level function that only supports python trees with array leaves,
  for custom objects use ``to_bytes``.  It splits arrays above MAX_CHUNK_SIZE into
  multiple chunks.

  Args:
    pytree: python tree of dict, list, tuple with python primitives
      and array leaves.
    in_place: boolean specifying if pytree should be modified in place.

  Returns:
    msgpack-encoded bytes of pytree.
  """
  if in_place:
    tree = pytree
  else:
    # Rebuild the containers so the helpers below do not mutate the input.
    tree = jax.tree_util.tree_map(lambda leaf: leaf, pytree)
  tree = _chunk_array_leaves_in_place(_np_convert_in_place(tree))
  return msgpack.packb(tree, default=_msgpack_ext_pack, strict_types=True)


def msgpack_restore(encoded_pytree: bytes):
  """Restore data structure from bytes in msgpack format.

  Low-level function that only supports python trees with array leaves,
  for custom objects use ``from_bytes``. Arrays that were stored in chunked
  form are reassembled.

  Args:
    encoded_pytree: msgpack-encoded bytes of python tree.

  Returns:
    Python tree of dict, list, tuple with python primitive
    and array leaves.
  """
  decoded = msgpack.unpackb(
    encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False
  )
  restored = _unchunk_array_leaves_in_place(decoded)
  return restored


def from_bytes(target, encoded_bytes: bytes):
  """Restore optimizer or other object from msgpack-serialized state-dict.

  The bytes are decoded with ``msgpack_restore`` and the resulting state dict
  is applied to ``target`` with ``from_state_dict``.

  Args:
    target: template object with state-dict registrations that matches
      the structure being deserialized from ``encoded_bytes``.
    encoded_bytes: msgpack serialized object structurally isomorphic to
      ``target``.  Typically a flax model or optimizer.

  Returns:
    A new object structurally isomorphic to ``target`` containing the updated
    leaf data from saved data.
  """
  return from_state_dict(target, msgpack_restore(encoded_bytes))


def to_bytes(target) -> bytes:
  """Save optimizer or other object as msgpack-serialized state-dict.

  The state dict built by ``to_state_dict`` is handed to
  ``msgpack_serialize`` with ``in_place=True``.

  Args:
    target: template object with state-dict registrations to be
      serialized to msgpack format.  Typically a flax model or optimizer.

  Returns:
    Bytes of msgpack-encoded state-dict of ``target`` object.
  """
  return msgpack_serialize(to_state_dict(target), in_place=True)
